Implementing the builtin modeling language

Author

McCoy Reynolds Becker

Published

December 7, 2022

Abstract
This notebook covers the design decisions, implementation ingredients, and interface implementations for the builtin modeling language in GenJAX. It assumes familiarity with the generative function interface, as well as familiarity with JAX’s support for composable interpreters and staging.
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import genjax
from genjax import GenerativeFunction, ChoiceMap, Selection, trace

sns.set_theme(style="white")

# Pretty printing.
console = genjax.pretty(width=70)

# Reproducibility.
key = jax.random.PRNGKey(314159)

One key property of the generative function interface is that it enables a separation between model and inference code - providing an abstraction layer that facilitates the development of modular model pieces, and then inference pieces that abstract over the implementation of the interface.

Now, implementing the interface on objects, and composing them in various ways (by e.g. specializing the implementation of the interface functions to support any intended composition) is a valid way to construct new generative functions. In fact, this is the pattern which generative function combinators follow - they accept generative functions as input, and produce new generative functions whose implementations are specialized to represent some specific pattern of computation.

Explicitly constructing generative functions using languages of objects, however, can often feel unwieldy. Part of the way that GenJAX (and Gen.jl) alleviates this is by exposing languages which construct generative functions from programs. This drastically increases the expressivity available to the programmer.

In GenJAX, here’s an example of the BuiltinGenerativeFunction language:

@genjax.gen
def model(x):
    y = genjax.trace("y", genjax.Normal)(x, 1.0)
    z = genjax.trace("z", genjax.Normal)(y + x, 1.0)
    return z

When we apply one of the interface functions to this object, we get the associated data types that we expect.

key, tr = model.simulate(key, (1.0,))
tr

BuiltinTrace
├── gen_fn
│   └── BuiltinGenerativeFunction
│       └── source
│           └── <function model>
├── args
│   └── tuple
│       └── (const) 1.0
├── retval
│   └──  f32[]
├── choices
│   └── Trie
│       ├── :y
│       │   └── DistributionTrace
│       │       ├── gen_fn
│       │       │   └── _Normal
│       │       ├── args
│       │       │   └── tuple
│       │       │       ├── (const) 1.0
│       │       │       └── (const) 1.0
│       │       ├── value
│       │       │   └──  f32[]
│       │       └── score
│       │           └──  f32[]
│       └── :z
│           └── DistributionTrace
│               ├── gen_fn
│               │   └── _Normal
│               ├── args
│               │   └── tuple
│               │       ├──  f32[]
│               │       └── (const) 1.0
│               ├── value
│               │   └──  f32[]
│               └── score
│                   └──  f32[]
├── cache
│   └── Trie
└── score
    └──  f32[]

How exactly do we do this? In this notebook, you’re going to find out. You’ll also get a chance to explore some of the capabilities which JAX exposes to library designers. Ideally, you’ll also get a sense of some of the limitations of JAX (and GenJAX) - both of which are restricted to programs amenable to GPU/TPU acceleration.

The magic of JAX

Let’s examine the generative function object:

model

BuiltinGenerativeFunction
└── source
    └── <function model>

All the decorator genjax.gen does is wrap the function into this object. It holds a reference to the function we defined above.
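The wrapping pattern itself is simple. Here’s a minimal sketch (with hypothetical names `my_gen`, `WrappedFunction`, and `toy_model` - GenJAX’s actual `BuiltinGenerativeFunction` implements the full generative function interface on top of this):

```python
class WrappedFunction:
    # Minimal stand-in for BuiltinGenerativeFunction: it just holds
    # a reference to the original Python function.
    def __init__(self, source):
        self.source = source


def my_gen(fn):
    # The decorator wraps the function in the object - nothing more.
    return WrappedFunction(fn)


@my_gen
def toy_model(x):
    return x + 1.0


print(toy_model.source)  # the underlying Python function
```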

But clearly, we need some way to get inside that function - because the data we record onto the BuiltinTrace comes from intermediate results of the function’s execution.

That’s where JAX comes in - JAX provides a way to trace pure, numerical Python programs - enabling us to construct program transformations which return new functions that compute different semantics from the original function.1

  • 1 Program tracing is an approach which has its roots in automatic differentiation. If you’re interested in this technique, I cannot recommend Autodidax: JAX core from scratch enough. It will introduce you to enough interesting PL ideas to keep you occupied for months, if not years.

    Let’s utilize one of JAX’s interpreters to construct an intermediate representation of the function which our generative function object holds a reference to:

    jaxpr = jax.make_jaxpr(model.source)(1.0)
    jaxpr
    { lambda ; a:f32[]. let
        b:key<fry>[] = random_seed[impl=fry] 0
        _:u32[2] = random_unwrap b
        c:f32[] = trace[addr=y gen_fn=_Normal() tree_in=PyTreeDef((*, *))] a 1.0
        d:f32[] = convert_element_type[new_dtype=float32 weak_type=False] a
        e:f32[] = add c d
        f:key<fry>[] = random_seed[impl=fry] 0
        _:u32[2] = random_unwrap f
        g:f32[] = trace[addr=z gen_fn=_Normal() tree_in=PyTreeDef((*, *))] e 1.0
      in (g,) }
    

    So jax.make_jaxpr takes a function f :: A -> B and returns a new function f' :: A -> Jaxpr, where Jaxpr is the program representation above.

    When we run this function using Python’s interpreter, JAX lifts the input to something called a Tracer. JAX keeps an internal stack of interpreters which redirect infix operations on Tracer instances and modify their behavior. Additionally, JAX exposes new primitives (like all the NumPy primitives) which wrap a function called bind. bind takes in Tracer arguments, looks through them (and the interpreter stack), selects the interpreter which should handle the call - and then that interpreter is allowed to process_primitive - invoking the semantics which the interpreter defines for that primitive.
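    A quick way to see the Tracer machinery at work (a minimal sketch - `peek` and `seen` are names invented for this demo):

```python
import jax

seen = []


def peek(x):
    # During tracing, `x` is not a Python float but a Tracer instance;
    # operations on it are redirected by JAX's interpreter stack.
    seen.append(type(x).__name__)
    return x * 2.0


jax.make_jaxpr(peek)(3.0)
print(seen)  # the recorded name ends in "Tracer"
```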

    jax.make_jaxpr uses the above process to walk the program, and construct the above intermediate representation.

    Now, the point of having this representation is that we can transform it further! We can lower it to other representations (including things like XLA - the linear algebra accelerator that JAX utilizes to go high performance). We could also write another interpreter which walks this representation, invokes other primitives with bind, etc - deferring further transformation to the next interpreter in line.

    This (admittedly rough description) above is the secret behind JAX’s compositional transformations.

    New semantics via program transformations

    Let’s examine the representation once more.

    jaxpr
    { lambda ; a:f32[]. let
        b:key<fry>[] = random_seed[impl=fry] 0
        _:u32[2] = random_unwrap b
        c:f32[] = trace[addr=y gen_fn=_Normal() tree_in=PyTreeDef((*, *))] a 1.0
        d:f32[] = convert_element_type[new_dtype=float32 weak_type=False] a
        e:f32[] = add c d
        f:key<fry>[] = random_seed[impl=fry] 0
        _:u32[2] = random_unwrap f
        g:f32[] = trace[addr=z gen_fn=_Normal() tree_in=PyTreeDef((*, *))] e 1.0
      in (g,) }
    

    You’ll notice that there is an intrinsic called trace here - which looks suspiciously similar to genjax.trace above.

    trace is a custom primitive that GenJAX defines - by defining a new primitive, we can place a stub in the intermediate representation, which we can further transform to implement the semantics we wish to express.
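    Here’s a minimal sketch of that trick, using a toy primitive named `stub` (not GenJAX’s actual `trace` implementation): defining a primitive lets us plant a labeled equation in the Jaxpr, which a later interpreter can treat specially.

```python
import jax

try:
    from jax.extend.core import Primitive  # newer JAX versions
except ImportError:
    from jax.core import Primitive  # older JAX versions

stub_p = Primitive("stub")

# Concrete semantics: the identity function.
stub_p.def_impl(lambda x: x)

# Abstract semantics: output shape/dtype match the input.
stub_p.def_abstract_eval(lambda x: x)


def f(x):
    # `bind` routes the call through JAX's interpreter stack.
    return stub_p.bind(x) + 1.0


print(jax.make_jaxpr(f)(1.0))  # `stub` appears as an equation in the Jaxpr
```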

    A high level view

    Now, we need to transform it! Here’s where some serious design decisions enter into the picture.

    One thing you might notice about the Jaxpr above is that the arity of the function is fixed, and so is the arity of the return value. But when we call simulate on our model - we get out something which is not a single h :: f32[] (it’s actually a jax.Pytree with a lot more data - so we’d expect a lot more return values in the Jaxpr)2.

  • 2 JAX flattens/unflattens Pytree instances on each side of the IR boundary - the IR is strongly typed, but only natively supports a few base types, and a few composite array types.

    What gives?

    Here’s where JAX’s support for compositional application of interpreters comes into play.

    Instead of attempting to modify the IR above to change the arity of everything (a process which the authors expect would be quite painful and bug-prone) - we can write another interpreter which walks the IR and evaluates it, but that interpreter can keep track of the state that we want to put into the BuiltinTrace at the end of the interface invocation.

    Then, we can stage out that interpreter to support JIT compilation, etc. I’ll describe the process below in pseudo-types:

    We start with f :: A -> B, and we stage it to get a new function f' :: Type[A] -> Jaxpr, then we write an interpreter I with signature I :: (Jaxpr, A) -> (B, State). The application of I itself can also be staged.

    So this is really nice - we don’t have to munge the IR manually, we just get to write an interpreter to do the transformation for us. That’s the power that JAX provides for us!
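    The `I :: (Jaxpr, A) -> (B, State)` pattern can be sketched with a tiny eqn-by-eqn evaluator (a simplification invented for this notebook - GenJAX’s real interpreter is the propagation interpreter discussed later; here the `State` is just a log of primitive names):

```python
import jax
import jax.numpy as jnp


def eval_with_state(closed_jaxpr, *args):
    jaxpr, consts = closed_jaxpr.jaxpr, closed_jaxpr.literals
    env = {}

    def read(v):
        # Literals carry their value directly; Vars live in the environment.
        return v.val if hasattr(v, "val") else env[v]

    for var, const in zip(jaxpr.constvars, consts):
        env[var] = const
    for var, arg in zip(jaxpr.invars, args):
        env[var] = arg

    state = []  # the extra `State` accumulated alongside evaluation
    for eqn in jaxpr.eqns:
        invals = [read(v) for v in eqn.invars]
        outvals = eqn.primitive.bind(*invals, **eqn.params)
        if not eqn.primitive.multiple_results:
            outvals = [outvals]
        for var, val in zip(eqn.outvars, outvals):
            env[var] = val
        state.append(eqn.primitive.name)
    return [read(v) for v in jaxpr.outvars], state


def g(x):
    return jnp.sin(x) + 1.0


outs, state = eval_with_state(jax.make_jaxpr(g)(0.0), 0.0)
print(outs, state)  # the log contains 'sin' and 'add'
```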

    Interpreter design decisions

    With the high-level view in mind, we’ll examine two of the interface implementations. The first is simulate - likely the easiest implementation to understand3. The second is update.

  • 3 For this notebook, we’re going to ignore the inference math that we wish to support!

    Now, in GenJAX, the interpreter is written to be re-usable for each of the interface functions. Because we’ve chosen to re-use the interpreter (and parametrize the transformation semantics by configuring the interpreter in other ways – besides the implementation), you’re going to see some complexity right out of the gate.

    This complexity exists because we wish to expose incremental computing optimizations in update. To support this customization, the interpreter can best be described as a propagation interpreter - similar to Julia’s abstract interpretation machinery (if you’re familiar). A propagation interpreter treats the Jaxpr as an undirected graph - and performs interpretation by iterating until a fixpoint condition is satisfied.

    The high level pattern from the previous section is still true! But if you’ve written interpreters for something like Structure and Interpretation of Computer Programs before - this interpreter might be a slight shock to the system.

    Here’s a boiled down form of the simulate_transform:

    def simulate_transform(f, **kwargs):
        def _inner(key, args):
            # Step 1: stage out the function to a `Jaxpr`.
            closed_jaxpr, (flat_args, in_tree, out_tree) = stage(f)(
                key, *args, **kwargs
            )
            jaxpr, consts = closed_jaxpr.jaxpr, closed_jaxpr.literals
    
            # Step 2: create a `Simulate` instance, which we parametrize
            # the propagation interpreter with.
            #
            # `Bare` is an instance of something called a `Cell` - the
            # objects which the propagation interpreter reasons about.
            handler = Simulate()
            final_env, ret_state = propagate(
                Bare,
                bare_propagation_rules,
                jaxpr,
                [Bare.new(v) for v in consts],
                list(map(Bare.new, flat_args)),
                [Bare.unknown(var.aval) for var in jaxpr.outvars],
                handler=handler,
            )
    
            # Step 3: when the interpreter finishes, we read the values
            # out of its environment.
            flat_out = safe_map(final_env.read, jaxpr.outvars)
            flat_out = map(lambda v: v.get_val(), flat_out)
            key_and_returns = jtu.tree_unflatten(out_tree, flat_out)
            key, *retvals = key_and_returns
            retvals = tuple(retvals)
    
            # Here's the handler state - remember the signature from
            # above `I :: (Jaxpr, A) -> (B, State)`, these fields
            # below are the `State`.
            score = handler.score
            chm = handler.choice_state
            cache = handler.cache_state
    
            # This returns all the things which we want to put
            # into `BuiltinTrace`.
            return key, (f, args, retvals, chm, score), cache
    
        return _inner

    And, just to show you that this is the key behind how we implement simulate, I’ve copied the BuiltinGenerativeFunction class method for simulate below:

    def simulate(self, key, args, **kwargs):
        assert isinstance(args, Tuple)
        key, (f, args, r, chm, score), cache = simulate_transform(
            self.source, **kwargs
        )(key, args)
        return key, BuiltinTrace(self, args, r, chm, cache, score)

    We’ll discuss propagate in a moment - but a few high-level things.

    Note that the simulate method can be staged out / used with JAX’s interfaces:

    jitted = jax.jit(model.simulate)
    key, tr = jitted(key, (1.0,))
    tr
    
    
    BuiltinTrace
    ├── gen_fn
    │   └── BuiltinGenerativeFunction
    │       └── source
    │           └── <function model>
    ├── args
    │   └── tuple
    │       └──  f32[]
    ├── retval
    │   └──  f32[]
    ├── choices
    │   └── Trie
    │       ├── :y
    │       │   └── DistributionTrace
    │       │       ├── gen_fn
    │       │       │   └── _Normal
    │       │       ├── args
    │       │       │   └── tuple
    │       │       │       ├──  f32[]
    │       │       │       └──  f32[]
    │       │       ├── value
    │       │       │   └──  f32[]
    │       │       └── score
    │       │           └──  f32[]
    │       └── :z
    │           └── DistributionTrace
    │               ├── gen_fn
    │               │   └── _Normal
    │               ├── args
    │               │   └── tuple
    │               │       ├──  f32[]
    │               │       └──  f32[]
    │               ├── value
    │               │   └──  f32[]
    │               └── score
    │                   └──  f32[]
    ├── cache
    │   └── Trie
    └── score
        └──  f32[]
    

    That’s because simulate_transform and the interpreter implementation itself for propagate are all JAX traceable.

    The only difference between the BuiltinTrace which we first generated at the top of the notebook and this one is that jax.jit will lift the 1.0 argument to a Tracer type, versus the non-jitted interpreter which just uses the Python float value.

    And again, we can also stage out our simulate implementation and get a Jaxpr back:

    jax.make_jaxpr(model.simulate)(key, (1.0,))
    { lambda ; a:u32[2] b:f32[]. let
        c:key<fry>[] = random_seed[impl=fry] 0
        _:u32[2] = random_unwrap c
        d:key<fry>[] = random_seed[impl=fry] 0
        _:u32[2] = random_unwrap d
        e:key<fry>[] = random_wrap[impl=fry] a
        f:key<fry>[2] = random_split[count=2] e
        g:u32[2,2] = random_unwrap f
        h:u32[1,2] = slice[limit_indices=(1, 2) start_indices=(0, 0) strides=(1, 1)] g
        i:u32[2] = squeeze[dimensions=(0,)] h
        j:u32[1,2] = slice[limit_indices=(2, 2) start_indices=(1, 0) strides=(1, 1)] g
        k:u32[2] = squeeze[dimensions=(0,)] j
        l:key<fry>[] = random_wrap[impl=fry] k
        m:u32[] = random_bits[bit_width=32 shape=()] l
        n:u32[] = shift_right_logical m 9
        o:u32[] = or n 1065353216
        p:f32[] = bitcast_convert_type[new_dtype=float32] o
        q:f32[] = sub p 1.0
        r:f32[] = sub 1.0 -0.9999999403953552
        s:f32[] = mul q r
        t:f32[] = add s -0.9999999403953552
        u:f32[] = reshape[dimensions=None new_sizes=()] t
        v:f32[] = max -0.9999999403953552 u
        w:f32[] = erf_inv v
        x:f32[] = mul 1.4142135381698608 w
        y:f32[] = mul 1.0 x
        z:f32[] = convert_element_type[new_dtype=float32 weak_type=False] b
        ba:f32[] = add z y
        bb:f32[] = convert_element_type[new_dtype=float32 weak_type=False] b
        bc:f32[] = sub ba bb
        bd:f32[] = div bc 1.0
        be:f32[] = abs bd
        bf:f32[] = integer_pow[y=2] be
        bg:f32[] = log 6.283185307179586
        bh:f32[] = convert_element_type[new_dtype=float32 weak_type=False] bg
        bi:f32[] = add bf bh
        bj:f32[] = mul -1.0 bi
        bk:f32[] = log 1.0
        bl:f32[] = sub 2.0 bk
        bm:f32[] = convert_element_type[new_dtype=float32 weak_type=False] bl
        bn:f32[] = div bj bm
        bo:f32[] = reduce_sum[axes=()] bn
        bp:f32[] = add 0.0 bo
        bq:f32[] = add ba b
        br:key<fry>[] = random_wrap[impl=fry] i
        bs:key<fry>[2] = random_split[count=2] br
        bt:u32[2,2] = random_unwrap bs
        bu:u32[1,2] = slice[
          limit_indices=(1, 2)
          start_indices=(0, 0)
          strides=(1, 1)
        ] bt
        bv:u32[2] = squeeze[dimensions=(0,)] bu
        bw:u32[1,2] = slice[
          limit_indices=(2, 2)
          start_indices=(1, 0)
          strides=(1, 1)
        ] bt
        bx:u32[2] = squeeze[dimensions=(0,)] bw
        by:key<fry>[] = random_wrap[impl=fry] bx
        bz:u32[] = random_bits[bit_width=32 shape=()] by
        ca:u32[] = shift_right_logical bz 9
        cb:u32[] = or ca 1065353216
        cc:f32[] = bitcast_convert_type[new_dtype=float32] cb
        cd:f32[] = sub cc 1.0
        ce:f32[] = sub 1.0 -0.9999999403953552
        cf:f32[] = mul cd ce
        cg:f32[] = add cf -0.9999999403953552
        ch:f32[] = reshape[dimensions=None new_sizes=()] cg
        ci:f32[] = max -0.9999999403953552 ch
        cj:f32[] = erf_inv ci
        ck:f32[] = mul 1.4142135381698608 cj
        cl:f32[] = mul 1.0 ck
        cm:f32[] = add bq cl
        cn:f32[] = sub cm bq
        co:f32[] = div cn 1.0
        cp:f32[] = abs co
        cq:f32[] = integer_pow[y=2] cp
        cr:f32[] = log 6.283185307179586
        cs:f32[] = convert_element_type[new_dtype=float32 weak_type=False] cr
        ct:f32[] = add cq cs
        cu:f32[] = mul -1.0 ct
        cv:f32[] = log 1.0
        cw:f32[] = sub 2.0 cv
        cx:f32[] = convert_element_type[new_dtype=float32 weak_type=False] cw
        cy:f32[] = div cu cx
        cz:f32[] = reduce_sum[axes=()] cy
        da:f32[] = add bp cz
      in (bv, b, cm, b, 1.0, ba, bo, bq, 1.0, cm, cz, da) }
    

    Giving us our pure, array math code. You can’t help but admit that that’s pretty elegant!

    How does propagate work?

    Now, in this section - we’re going to talk about the nitty gritty of propagate itself. What exactly is this interpreter doing? Let’s examine the context surrounding the call to propagate:

    def simulate_transform(f, **kwargs):
        def _inner(key, args):
            closed_jaxpr, (flat_args, in_tree, out_tree) = stage(f)(
                key, *args, **kwargs
            )
            jaxpr, consts = closed_jaxpr.jaxpr, closed_jaxpr.literals
            handler = Simulate()
            final_env, ret_state = propagate(
                # A lattice type
                Bare,
                
                # Lattice propagation rules
                bare_propagation_rules,
                
                # The Jaxpr which we wish to interpret
                jaxpr,
                
                # Trace-time constants
                [Bare.new(v) for v in consts],
                
                # Input cells
                list(map(Bare.new, flat_args)),
                
                # Output cells
                [Bare.unknown(var.aval) for var in jaxpr.outvars],
                
                # How we handle `trace`.
                handler=handler,
            )
            ...
    
        return _inner

    First, we stage our model function into a Jaxpr - when we perform the staging process, everything (e.g. custom datatypes which are Pytree implementors) gets flattened out to array leaves.
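    The flattening step uses JAX’s pytree utilities - for example:

```python
import jax.tree_util as jtu

pytree = {"y": 1.0, "z": (2.0, 3.0)}

# Flatten to a list of leaves plus a treedef describing the structure.
leaves, treedef = jtu.tree_flatten(pytree)
print(leaves)  # [1.0, 2.0, 3.0]

# The treedef lets us rebuild the original structure on the way out.
print(jtu.tree_unflatten(treedef, leaves))  # {'y': 1.0, 'z': (2.0, 3.0)}
```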

    After we stage, we collect all the data which we want to use to initialize our interpreter’s environment with - but we encounter our first bit of complexity.

    What is Bare? And what is a Cell? Let’s start with the latter question: a Cell is an abstract type which represents a lattice value.

    To understand what a lattice value is - it’s worth gaining a high-level picture of what propagate attempts to do. propagate is an interpreter based on mixed concrete/abstract interpretation - it treats the Jaxpr as a graph - where the operations are nodes in the graph, and the SSA values (e.g. the named registers like ci, cj, etc) are edges.

    The interpreter will iterate over the graph - attempting to update information about the edges by applying propagation rules (hence the name, propagate) which we define (bare_propagation_rules, above).

    A propagation rule accepts a list of input cells (the SSA edges which flow into the operation) and a list of output cells. It returns a new modified list of input cells, and a new modified list of output cells, as well as a state value (in this notebook, we won’t discuss the state value - it’s unneeded for the interfaces we will describe).

    The way the interpreter works is that it keeps a queue of nodes and an environment which maps SSA values to lattice values. We pop a node off the queue, grab the existing lattice values for input SSA values and output SSA values, attempt to update them using a propagation rule, and then store the update in the environment. In addition, after we attempt to update the cells - we determine if the update has changed the information level of any of the cells. If the information level has changed for any cell (as measured using the partial order on lattice values), we add any nodes which the SSA value associated with that cell flows into back onto the queue.

    This process describes an iterative algorithm which attempts to compute an information fixpoint - defined by a state transition function (which operates on the state of all cells in the Jaxpr - the environment) which we get to customize using propagation rules.
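    As a toy illustration of the worklist-and-fixpoint idea (entirely hypothetical names - this is not the real propagate, which is bidirectional and operates on Cell lattices), here is constant propagation over a small equation graph, where `None` plays the role of an unknown cell:

```python
def propagate_toy(eqns, env):
    """eqns: list of (rule, in_vars, out_var). Worklist fixpoint:
    re-enqueue an equation whenever one of its inputs gains information."""
    queue = list(eqns)
    while queue:
        rule, in_vars, out_var = queue.pop(0)
        new = rule(*[env.get(v) for v in in_vars])
        # Only record monotone updates: unknown (None) -> known.
        if new is not None and env.get(out_var) is None:
            env[out_var] = new
            # Information level changed: requeue consumers of out_var.
            queue.extend(e for e in eqns if out_var in e[1])
    return env


# Example graph: c = a + b; d = c * 2, with a and b initially known.
eqns = [
    (lambda a, b: None if a is None or b is None else a + b, ("a", "b"), "c"),
    (lambda c: None if c is None else c * 2.0, ("c",), "d"),
]
env = propagate_toy(eqns, {"a": 1.0, "b": 2.0})
print(env["d"])  # 6.0
```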

    I’m not going to inline any of the implementation of this interpreter into this notebook. I’ll refer the reader to the implementation of the interpreter.4

  • 4 Note that the ideas behind this interpreter are quite widespread - but the original implementation (which the GenJAX authors modified) came from Oryx, and that implementation initially came from Roy Frostig (as far as we can tell).

    What happens in simulate?

    Great - so how do we utilize this interpreter idea to implement the simulate_transform described above?